{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "8d0c7760",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   pixel1  pixel2  pixel3  pixel4  pixel5  pixel6  pixel7  pixel8  pixel9  \\\n",
      "0       0       0       0       0       0       0       0       0       0   \n",
      "1       0       0       0       0       0       0       0       0       0   \n",
      "2       0       0       0       0       0       0       0       0       0   \n",
      "3       0       0       0       0       0       0       0       0       0   \n",
      "4       0       0       0       0       0       0       0       0       0   \n",
      "\n",
      "   pixel10  ...  pixel775  pixel776  pixel777  pixel778  pixel779  pixel780  \\\n",
      "0        0  ...         0         0         0         0         0         0   \n",
      "1        0  ...         0         0         0         0         0         0   \n",
      "2        0  ...         0         0         0         0         0         0   \n",
      "3        0  ...         0         0         0         0         0         0   \n",
      "4        0  ...         0         0         0         0         0         0   \n",
      "\n",
      "   pixel781  pixel782  pixel783  pixel784  \n",
      "0         0         0         0         0  \n",
      "1         0         0         0         0  \n",
      "2         0         0         0         0  \n",
      "3         0         0         0         0  \n",
      "4         0         0         0         0  \n",
      "\n",
      "[5 rows x 784 columns]\n",
      "(70000, 784)\n",
      "均方误差: 1.4565\n",
      "R2 分数: 0.8260080236159102\n",
      "模型准确率: 0.9154285714285715\n"
     ]
    }
   ],
   "source": [
    "from sklearn.datasets import fetch_openml\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.linear_model import LogisticRegression  # 逻辑回归模型\n",
    "from sklearn.model_selection import train_test_split # 训练集测试集划分\n",
    "from sklearn.metrics import mean_squared_error, r2_score\n",
    "import pandas as pd\n",
    "\n",
    "# 载入 MNIST 手写数字图片数据集\n",
    "mnist = fetch_openml(\"mnist_784\")\n",
    "\n",
    "df = pd.DataFrame(mnist.data)\n",
    "print(df.head())\n",
    "\n",
    "# 数据预处理，标准归一化\n",
    "scaler = StandardScaler()\n",
    "X = scaler.fit_transform(mnist.data)\n",
    "print(X.shape)\n",
    "\n",
    "# 划分数据集和测试集\n",
    "X_train, X_test, y_train, y_test = train_test_split(X, mnist.target, test_size=0.2, random_state=42)\n",
    "\n",
    "# 创建逻辑回归模型\n",
    "model = LogisticRegression(max_iter=1000)\n",
    "\n",
    "# 在训练集上训练模型\n",
    "model.fit(X_train, y_train)\n",
    "\n",
    "# 在测试集上进行预测\n",
    "y_pred = model.predict(X_test)\n",
    "\n",
    "# 输出模型的准确率\n",
    "mse = mean_squared_error(y_test, y_pred)\n",
    "r2 = r2_score(y_test, y_pred)\n",
    "\n",
    "print(\"均方误差:\", mse)\n",
    "print(\"R2 分数:\", r2)\n",
    "print(\"模型准确率:\", model.score(X_test, y_test))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "46056057",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 70000 entries, 0 to 69999\n",
      "Columns: 784 entries, pixel1 to pixel784\n",
      "dtypes: int64(784)\n",
      "memory usage: 418.7 MB\n"
     ]
    }
   ],
   "source": [
    "df.info()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1de08e8d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>pixel1</th>\n",
       "      <th>pixel2</th>\n",
       "      <th>pixel3</th>\n",
       "      <th>pixel4</th>\n",
       "      <th>pixel5</th>\n",
       "      <th>pixel6</th>\n",
       "      <th>pixel7</th>\n",
       "      <th>pixel8</th>\n",
       "      <th>pixel9</th>\n",
       "      <th>pixel10</th>\n",
       "      <th>...</th>\n",
       "      <th>pixel775</th>\n",
       "      <th>pixel776</th>\n",
       "      <th>pixel777</th>\n",
       "      <th>pixel778</th>\n",
       "      <th>pixel779</th>\n",
       "      <th>pixel780</th>\n",
       "      <th>pixel781</th>\n",
       "      <th>pixel782</th>\n",
       "      <th>pixel783</th>\n",
       "      <th>pixel784</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>70000.0</td>\n",
       "      <td>70000.0</td>\n",
       "      <td>70000.0</td>\n",
       "      <td>70000.0</td>\n",
       "      <td>70000.0</td>\n",
       "      <td>70000.0</td>\n",
       "      <td>70000.0</td>\n",
       "      <td>70000.0</td>\n",
       "      <td>70000.0</td>\n",
       "      <td>70000.0</td>\n",
       "      <td>...</td>\n",
       "      <td>70000.000000</td>\n",
       "      <td>70000.000000</td>\n",
       "      <td>70000.000000</td>\n",
       "      <td>70000.000000</td>\n",
       "      <td>70000.000000</td>\n",
       "      <td>70000.000000</td>\n",
       "      <td>70000.0</td>\n",
       "      <td>70000.0</td>\n",
       "      <td>70000.0</td>\n",
       "      <td>70000.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.197414</td>\n",
       "      <td>0.099543</td>\n",
       "      <td>0.046629</td>\n",
       "      <td>0.016614</td>\n",
       "      <td>0.012957</td>\n",
       "      <td>0.001714</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>5.991206</td>\n",
       "      <td>4.256304</td>\n",
       "      <td>2.783732</td>\n",
       "      <td>1.561822</td>\n",
       "      <td>1.553796</td>\n",
       "      <td>0.320889</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>254.000000</td>\n",
       "      <td>254.000000</td>\n",
       "      <td>253.000000</td>\n",
       "      <td>253.000000</td>\n",
       "      <td>254.000000</td>\n",
       "      <td>62.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>8 rows × 784 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        pixel1   pixel2   pixel3   pixel4   pixel5   pixel6   pixel7   pixel8  \\\n",
       "count  70000.0  70000.0  70000.0  70000.0  70000.0  70000.0  70000.0  70000.0   \n",
       "mean       0.0      0.0      0.0      0.0      0.0      0.0      0.0      0.0   \n",
       "std        0.0      0.0      0.0      0.0      0.0      0.0      0.0      0.0   \n",
       "min        0.0      0.0      0.0      0.0      0.0      0.0      0.0      0.0   \n",
       "25%        0.0      0.0      0.0      0.0      0.0      0.0      0.0      0.0   \n",
       "50%        0.0      0.0      0.0      0.0      0.0      0.0      0.0      0.0   \n",
       "75%        0.0      0.0      0.0      0.0      0.0      0.0      0.0      0.0   \n",
       "max        0.0      0.0      0.0      0.0      0.0      0.0      0.0      0.0   \n",
       "\n",
       "        pixel9  pixel10  ...      pixel775      pixel776      pixel777  \\\n",
       "count  70000.0  70000.0  ...  70000.000000  70000.000000  70000.000000   \n",
       "mean       0.0      0.0  ...      0.197414      0.099543      0.046629   \n",
       "std        0.0      0.0  ...      5.991206      4.256304      2.783732   \n",
       "min        0.0      0.0  ...      0.000000      0.000000      0.000000   \n",
       "25%        0.0      0.0  ...      0.000000      0.000000      0.000000   \n",
       "50%        0.0      0.0  ...      0.000000      0.000000      0.000000   \n",
       "75%        0.0      0.0  ...      0.000000      0.000000      0.000000   \n",
       "max        0.0      0.0  ...    254.000000    254.000000    253.000000   \n",
       "\n",
       "           pixel778      pixel779      pixel780  pixel781  pixel782  pixel783  \\\n",
       "count  70000.000000  70000.000000  70000.000000   70000.0   70000.0   70000.0   \n",
       "mean       0.016614      0.012957      0.001714       0.0       0.0       0.0   \n",
       "std        1.561822      1.553796      0.320889       0.0       0.0       0.0   \n",
       "min        0.000000      0.000000      0.000000       0.0       0.0       0.0   \n",
       "25%        0.000000      0.000000      0.000000       0.0       0.0       0.0   \n",
       "50%        0.000000      0.000000      0.000000       0.0       0.0       0.0   \n",
       "75%        0.000000      0.000000      0.000000       0.0       0.0       0.0   \n",
       "max      253.000000    254.000000     62.000000       0.0       0.0       0.0   \n",
       "\n",
       "       pixel784  \n",
       "count   70000.0  \n",
       "mean        0.0  \n",
       "std         0.0  \n",
       "min         0.0  \n",
       "25%         0.0  \n",
       "50%         0.0  \n",
       "75%         0.0  \n",
       "max         0.0  \n",
       "\n",
       "[8 rows x 784 columns]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "389eb7cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "from torch.utils.data import DataLoader\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "687fedc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((28, 28)),                  # 调整图像大小为28x28\n",
    "    transforms.ToTensor(),                        # 转换为Tensor\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "b13ae2a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 9.91M/9.91M [00:03<00:00, 3.00MB/s]\n",
      "100%|██████████| 28.9k/28.9k [00:00<00:00, 593kB/s]\n",
      "100%|██████████| 1.65M/1.65M [00:01<00:00, 957kB/s] \n",
      "100%|██████████| 4.54k/4.54k [00:00<00:00, 4.61MB/s]\n"
     ]
    }
   ],
   "source": [
    "# 加载训练集数据\n",
    "train_dataset = datasets.MNIST(\n",
    "    root='./data',  # 数据存储的路径\n",
    "    train=True,  # 表示加载训练数据集\n",
    "    download=True,  # 如果数据集没有下载，会自动下载\n",
    "    transform=transform  # 对数据应用transform操作，例如转换为Tensor\n",
    ")\n",
    "\n",
    "# 加载测试集数据，MNIST数据集的测试部分，train=False表示加载测试集\n",
    "test_dataset = datasets.MNIST(\n",
    "    root='./data',  # 数据存储的路径\n",
    "    train=False,  # 表示加载测试数据集\n",
    "    download=True,  # 如果数据集没有下载，会自动下载\n",
    "    transform=transform  # 对数据应用transform操作，例如转换为Tensor\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "6526a3f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaqElEQVR4nO3df2xV9f3H8VeL9ILaXiylvb2jQEEFwy8ng9rwYygNtC4GtEtA/QMWAoFdzLDzx7qIKFvSjSWOuCD+s8BMxF+JQCRLMym2hNliqDDCph3tugGBFsVxbylSGP18/yDer1cKeMq9ffdeno/kJPTe8+l9ezzhyWlvT9Occ04AAPSxdOsBAAA3JwIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBM3GI9wLd1d3frxIkTyszMVFpamvU4AACPnHPq6OhQMBhUevrVr3P6XYBOnDihgoIC6zEAADfo2LFjGj58+FWf73dfgsvMzLQeAQAQB9f7+zxhAdq4caNGjRqlQYMGqaioSB9//PF3WseX3QAgNVzv7/OEBOjtt99WRUWF1q5dq08++USTJ0/WvHnzdOrUqUS8HAAgGbkEmDZtmguFQtGPL1265ILBoKuqqrru2nA47CSxsbGxsSX5Fg6Hr/n3fdyvgC5cuKDGxkaVlJREH0tPT1dJSYnq6+uv2L+rq0uRSCRmAwCkvrgH6IsvvtClS5eUl5cX83heXp7a2tqu2L+qqkp+vz+68Q44ALg5mL8LrrKyUuFwOLodO3bMeiQAQB+I+88B5eTkaMCAAWpvb495vL29XYFA4Ir9fT6ffD5fvMcAAPRzcb8CysjI0JQpU1RTUxN9rLu7WzU1NSouLo73ywEAklRC7oRQUVGhxYsX6wc/+IGmTZumDRs2qLOzUz/5yU8S8XIAgCSUkAAtXLhQn3/+uV544QW1tbXp3nvvVXV19RVvTAAA3LzSnHPOeohvikQi8vv91mMAAG5QOBxWVlbWVZ83fxccAODmRIAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATMQ9QC+++KLS0tJitnHjxsX7ZQAASe6WRHzS8ePHa9euXf//Irck5GUAAEksIWW45ZZbFAgEEvGpAQApIiHfAzpy5IiCwaBGjx6tJ554QkePHr3qvl1dXYpEIjEbACD1xT1ARUVF2rJli6qrq7Vp0ya1trZq5syZ6ujo6HH/qqoq+f3+6FZQUBDvkQAA/VCac84l8gXOnDmjkSNH6uWXX9bSpUuveL6rq0tdXV3RjyORCBECgBQQDoeVlZV11ecT/u6AIUOG6O6771Zzc3OPz/t8Pvl8vkSPAQDoZxL+c0Bnz55VS0uL8vPzE/1SAIAkEvcAPf3006qrq9O///1vffTRR3rkkUc0YMAAPfbYY/F+KQBAEov7l+COHz+uxx57TKdPn9awYcM0Y8YMNTQ0aNiwYfF+KQBAEkv4mxC8ikQi8vv91mMAAG7Q9d6EwL3gAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATCf+FdOhbP/7xjz2vWbZsWa9e68SJE57XnD9/3vOaN954w/OatrY2z2skXfUXJwKIP66AAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYCLNOeesh/imSCQiv99vPUbS+te//uV5zahRo+I/iLGOjo5erfv73/8e50kQb8ePH/e8Zv369b16rf379/dqHS4Lh8PKysq66vNcAQEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJm6xHgDxtWzZMs9rJk2a1KvX+vTTTz2vueeeezyvue+++zyvmT17tuc1knT//fd7XnPs2DHPawoKCjyv6Uv/+9//PK/5/PPPPa/Jz8/3vKY3jh492qt13Iw0sbgCAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMcDPSFFNTU9Mna3qrurq6T17njjvu6NW6e++91/OaxsZGz2umTp3qeU1fOn/+vOc1//znPz2v6c0NbbOzsz2vaWlp8bwGiccVEADABAECAJjwHKA9e/bo4YcfVjAYVFpamrZv3x7zvHNOL7zwgvLz8zV48GCVlJToyJEj8ZoXAJAiPAeos7NTkydP1saNG3t8fv369XrllVf02muvad++fbrttts0b968Xn1NGQCQujy/CaGsrExlZWU9Puec04YNG/T8889r/vz5kqTXX39deXl52r59uxYtWnRj0wIAUkZcvwfU2tqqtrY2lZSURB/z+/0qKipSfX19j2u6uroUiURiNgBA6otrgNra2iRJeXl5MY/n5eVFn/u2qqoq+f3+6FZQUBDPkQAA/ZT5u+AqKysVDoej27Fjx6xHAgD0gbgGKBAISJLa29tjHm9vb48+920+n09ZWVkxGwAg9cU1QIWFhQoEAjE/WR+JRLRv3z4VFxfH86UAAEnO87vgzp49q+bm5ujHra2tOnjwoLKzszVixAitXr1av/71r3XXXXepsLBQa9asUTAY1IIFC+I5NwAgyXkO0P79+/XAAw9EP66oqJAkLV68WFu2bNGzzz6rzs5OLV++XGfOnNGMGTNUXV2tQYMGxW9qAEDSS3POOeshvikSicjv91uPAcCj8vJyz2veeecdz2sOHz7sec03/9HsxZdfftmrdbgsHA5f8/v65u+CAwDcnAgQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGDC869jAJD6cnNzPa959dVXPa9JT/f+b+B169Z5XsNdrfsnroAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABPcjBTAFUKhkOc1w4YN87zmv//9r+c1TU1Nntegf+IKCABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwwc1IgRQ2ffr0Xq37xS9+EedJerZgwQLPaw4fPhz/QWCCKyAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQ3IwVS2EMPPdSrdQMHDvS8pqamxvOa+vp6z2uQOrgCAgCYIEAAABOeA7Rnzx49/PDDCgaDSktL0/bt22OeX7JkidLS0mK20tLSeM0LAEgRngPU2dmpyZMna+PGjVfdp7S0VCdPnoxub7755g0NCQBIPZ7fhFBWVqaysrJr7uPz+RQIBHo9FAAg9SXke0C1tbXKzc3V2LFjtXLlSp0+ffqq+3Z1dSkSicRsAIDUF/cAlZaW6vXXX1dNTY1++9vfqq6uTmVlZbp06VKP+1dVVcnv90e3goKCeI8EAOiH4v5zQIsWLYr+eeLEiZo0aZLGjBmj2tpazZkz54r9KysrVVFREf04EokQIQC4CST8bdijR49WTk6Ompube3ze5/MpKysrZgMApL6EB+j48eM6ffq08vPzE/1SAIAk4vlLcGfPno25mmltbdXBgweVnZ2t7OxsvfTSSyovL1cgEFBLS4ueffZZ3XnnnZo3b15cBwcAJDfPAdq/f78eeOCB6Mdff/9m8eLF2rRpkw4dOqQ//elPOnPmjILBoObOnatf/epX8vl88ZsaAJD00pxzznqIb4pEIvL7/dZjAP3O4MGDPa/Zu3dvr15r/Pjxntc8+OCDntd89NFHntcgeYTD4Wt+X597wQEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMBE3H8lN4DEeOaZZzyv+f73v9+r16qurva8hjtbwyuugAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAE9yMFDDwox/9yPOaNWvWeF4TiUQ8r5GkdevW9Wod4AVXQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACW5GCtygoUOHel7zyiuveF4zYMAAz2v+/Oc/e14jSQ0NDb1aB3jBFRAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIKbkQLf0JsbflZXV3teU1hY6HlNS0uL5zVr1qzxvAboK1wBAQBMECAAgAlPAaqqqtLUqVOVmZmp3NxcLViwQE1NTTH7nD9/XqFQSEOHDtXtt9+u8vJytbe3x3VoAEDy8xSguro6hUIhNTQ06IMPPtDFixc1d+5cdXZ2Rvd56qmn9P777+vdd99VXV2dTpw4oUcffTTugwMAkpunNyF8+5utW7ZsUW5urhobGzVr1iyFw2H98Y9/1NatW/Xggw9KkjZv3qx77rlHDQ0Nuv/+++M3OQAgqd3Q94DC4bAkKTs7W5LU2NioixcvqqSkJLrPuHHjNGLECNXX1/f4Obq6uhSJRGI2AEDq63WAuru7tXr1ak2fPl0TJkyQJLW1tSkjI0NDhgyJ2TcvL09tbW09fp6qqir5/f7oVlBQ0NuRAABJpNcBCoVCOnz4sN56660bGqCyslLhcDi6HTt27IY+HwAgOfTqB1FXrVqlnTt3as+ePRo+fHj08UAgoAsXLujMmTMxV0Ht7e0KBAI9fi6fzyefz9ebMQAASczTFZBzTqtWrdK2bdu0e/fuK36ae8qUKRo4cKBqamqijzU1Neno0aMqLi6Oz8QAgJTg6QooFApp69at2rFjhzIzM6Pf1/H7/Ro8eLD8fr+WLl2qiooKZWdnKysrS08++aSKi4t5BxwAIIanAG3atEmSNHv27JjHN2/erCVLlkiSfv/73ys9PV3l5eXq6urSvHnz9Oqrr8ZlWABA6khzzjnrIb4pEonI7/dbj4Gb1N133+15zWeffZaASa40f/58z2vef//9BEwCfDfhcFhZWVlXfZ57wQEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMBEr34jKtDfjRw5slfr/vKXv8R5kp4988wzntfs3LkzAZMAdrgCAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMcDNSpKTly5f3at2IESPiPEnP6urqPK9xziVgEsAOV0AAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAluRop+b8aMGZ7XPPnkkwmYBEA8cQUEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJjgZqTo92bOnOl5ze23356ASXrW0tLiec3Zs2cTMAmQXLgCAgCYIEAAABOeAlRVVaWpU6cqMzNTubm5WrBggZqammL2mT17ttLS0mK2FStWxHVoAEDy8xSguro6hUIhNTQ06IMPPtDFixc1d+5cdXZ2xuy3bNkynTx5MrqtX78+rkMDAJKfpzchVFdXx3y8ZcsW5ebmqrGxUbNmzYo+fuuttyoQCMRnQgBASrqh7wGFw2FJUnZ2dszjb7zxhnJycjRhwgRVVlbq3LlzV/0cXV1dikQiMRsAIPX1+m3Y3d3dWr16taZPn64JEyZEH3/88cc1cuRIBYNBHTp0SM8995yampr03nvv9fh5qqqq9NJLL/V2DABAkup1gEKhkA4fPqy9e/fGPL58+fLonydOnKj8/HzNmTNHLS0tGjNmzBWfp7KyUhUVFdGPI5GICgoKejsWACBJ9CpAq1at0s6dO7Vnzx4NHz78mvsWFRVJkpqbm3sMkM/nk8/n680YAIAk5ilAzjk9+eST2rZtm2pra1VYWHjdNQcPHpQk5efn92pAAEBq8hSgUCikrVu3aseOHcrMzFRbW5skye/3a/DgwWppadHWrVv10EMPaejQoTp06JCeeuopzZo1S5MmTUrIfwAAIDl5CtCmTZskXf5h02/avHmzlixZooyMDO3atUsbNmxQZ2enCgoKVF5erueffz5uAwMAUoPnL8FdS0FBgerq6m5oIADAzYG7YQPf8Le//c3zmjlz5nhe8+WXX3peA6QabkYKADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJhIc9e7xXUfi0Qi8vv91mMAAG5QOBxWVlbWVZ/nCggAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAICJfhegfnZrOgBAL13v7/N+F6COjg7rEQAAcXC9v8/73d2wu7u7deLECWVmZiotLS3muUgkooKCAh07duyad1hNdRyHyzgOl3EcLuM4XNYfjoNzTh0dHQoGg0pPv/p1zi19ONN3kp6eruHDh19zn6ysrJv6BPsax+EyjsNlHIfLOA6XWR+H7/Jrdfrdl+AAADcHAgQAMJFUAfL5fFq7dq18Pp/1KKY4DpdxHC7jOFzGcbgsmY5Dv3sTAgDg5pBUV0AAgNRBgAAAJggQAMAEAQIAmEiaAG3cuFGjRo3SoEGDVFRUpI8//th6pD734osvKi0tLWYbN26c9VgJt2fPHj388MMKBoNKS0vT9u3bY553zumFF15Qfn6+Bg8erJKSEh05csRm2AS63nFYsmTJFedHaWmpzbAJUlVVpalTpyozM1O5ublasGCBmpqaYvY5f/68QqGQhg4dqttvv13l5eVqb283mjgxvstxmD179hXnw4oVK4wm7llSBOjtt99WRUWF1q5dq08++USTJ0/WvHnzdOrUKevR+tz48eN18uTJ6LZ3717rkRKus7NTkydP1saNG3t8fv369XrllVf02muvad++fbrttts0b948nT9/vo8nTazrHQdJKi0tjTk/3nzzzT6cMPHq6uoUCoXU0NCgDz74QBcvXtTcuXPV2dkZ3eepp57S+++/r3fffVd1dXU6ceKEHn30UcOp4++7HAdJWrZsWcz5sH79eqOJr8IlgWnTprlQKBT9+NKlSy4YDLqqqirDqfre2rVr3eTJk63HMCXJbdu2Lfpxd3e3CwQC7ne/+130sTNnzjifz+fefPNNgwn7xrePg3POLV682M2fP99kHiunTp1yklxdXZ1z7vL/+4EDB7p33303us+nn37qJLn6+nqrMRPu28fBOed++MMfup/97Gd2Q30H/f4K6MKFC2psbFRJSUn0sfT0dJWUlKi+vt5wMhtHjhxRMBjU6NGj9cQTT+jo0aPWI5lqbW1VW1tbzPnh9/tVVFR0U54ftbW1ys3N1dixY7Vy5UqdPn3aeqSECofDkqTs7GxJUmNjoy5evBhzPowbN04jRoxI6fPh28fha2+88YZycnI0YcIEVVZW6ty5cxbjXVW/uxnpt33xxRe6dOmS8vLyYh7Py8vTZ599ZjSVjaKiIm3ZskVjx47VyZMn9dJLL2nmzJk6fPiwMjMzrccz0dbWJkk9nh9fP3ezKC0t1aOPPqrCwkK1tLTol7/8pcrKylRfX68BAwZYjxd33d3dWr16taZPn64JEyZIunw+ZGRkaMiQITH7pvL50NNxkKTHH39cI0eOVDAY1KFDh/Tcc8+pqalJ7733nuG0sfp9gPD/ysrKon+eNGmSioqKNHLkSL3zzjtaunSp4WToDxYtWhT988SJEzVp0iSNGTNGtbW1mjNnjuFkiREKhXT48OGb4vug13K147B8+fLonydOnKj8/HzNmTNHLS0tGjNmTF+P2aN+/yW4nJwcDRgw4Ip3sbS3tysQCBhN1T8MGTJEd999t5qbm61HMfP1OcD5caXRo0crJycnJc+PVatWaefOnfrwww9jfn1LIBDQhQsXdObMmZj9U/V8uNpx6ElRUZEk9avzod8HKCMjQ1OmTFFNTU30se7ubtXU1Ki4uNhwMntnz55VS0uL8vPzrUcxU1hYqEAgEHN+RCIR7du376Y/P44fP67Tp0+n1PnhnNOqVau0bds27d69W4WFhTHPT5kyRQMHDow5H5qamnT06NGUOh+udxx6cvDgQUnqX+eD9bsgvou33nrL+Xw+t2XLFvePf/zDLV++3A0ZMsS1tbVZj9anfv7zn7va2lrX2trq/vrXv7qSkhKXk5PjTp06ZT1aQnV0dLgDBw64AwcOOEnu5ZdfdgcOHHD/+c9/nHPO/eY3v3FDhgxxO3bscIcOHXLz5893hYWF7quvvjKePL6udRw6Ojrc008/7err611ra6vbtWuXu++++9xdd93lzp8/bz163KxcudL5/X5XW1vrTp48Gd3OnTsX3WfFihVuxIgRbvfu3W7//v2uuLjYFRcXG04df9c7Ds3NzW7dunVu//79rrW11e3YscONHj3azZo1y3jyWEkRIOec+8Mf/uBGjBjhMjIy3LRp01xDQ4P1SH1u4cKFLj8/32VkZLjvfe97buHCha65udl6rIT78MMPnaQrtsWLFzvnLr8Ve82aNS4vL8/5fD43Z84c19TUZDt0AlzrOJw7d87NnTvXDRs2zA0cONCNHDnSLVu2LOX+kdbTf78kt3nz5ug+X331lfvpT3/q7rjjDnfrrbe6Rx55xJ08edJu6AS43nE4evSomzVrlsvOznY+n8/deeed7plnnnHhcNh28G/h1zEAAEz0++8BAQBSEwECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABg4v8AjVqFRqQZEfIAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "img, label = test_dataset[0]\n",
    "img = img.permute(1, 2, 0)  #  Tensor，转换成 (H, W, C) 格式\n",
    "plt.imshow(img, cmap='gray') \n",
    "print(label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "681b34db",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# 使用DataLoader来加载训练数据\n",
    "\n",
    "train_loader = DataLoader(\n",
    "    train_dataset,  # 传入训练数据集\n",
    "    batch_size=32,  # 批次数量\n",
    "    shuffle=True,  # 打乱数据\n",
    "    num_workers=8  # 加载数据进程数\n",
    ")\n",
    "\n",
    "test_loader = DataLoader(\n",
    "    test_dataset,  # 传入测试数据集\n",
    "    batch_size=32,  # 每批次数量\n",
    "    shuffle=False,  # 不打乱数据\n",
    "    num_workers=8  # 加载数据进程数\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "e8ffa20d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# 检查是否有可用的GPU，如果有就使用GPU，否则使用CPU\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "8c474cc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义LeNet-5网络结构\n",
    "class LeNet5(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(LeNet5, self).__init__()\n",
    "        \n",
    "        # 第一层卷积层：输入1通道，输出6通道，卷积核5x5，padding=same\n",
    "        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)\n",
    "        # 第一层池化层：2x2平均池化\n",
    "        self.pool1 = nn.AvgPool2d(kernel_size=2)\n",
    "        # 第二层卷积层：输入6通道，输出16通道，卷积核5x5\n",
    "        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)\n",
    "        # 第二层池化层：2x2平均池化\n",
    "        self.pool2 = nn.AvgPool2d(kernel_size=2)\n",
    "        # 第三层卷积层：输入16通道，输出120通道，卷积核5x5\n",
    "        self.conv3 = nn.Conv2d(16, 120, kernel_size=5)\n",
    "        # 展平层\n",
    "        self.flatten = nn.Flatten()\n",
    "        # 全连接层：输入120，输出84\n",
    "        self.fc1 = nn.Linear(120, 84)\n",
    "        # 全连接层：输入84，输出10（对应10个类别）\n",
    "        self.fc2 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # 第一层卷积 + 激活 + 池化\n",
    "        x = torch.sigmoid(self.conv1(x))\n",
    "        x = self.pool1(x)\n",
    "        # 第二层卷积 + 激活 + 池化\n",
    "        x = torch.sigmoid(self.conv2(x))\n",
    "        x = self.pool2(x)\n",
    "        # 第三层卷积 + 激活\n",
    "        x = torch.sigmoid(self.conv3(x))\n",
    "        # 展平\n",
    "        x = self.flatten(x)  # 使用 flatten 层\n",
    "        # 全连接层 + 激活\n",
    "        x = torch.sigmoid(self.fc1(x))\n",
    "        # 输出层\n",
    "        x = self.fc2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "a3e69f0d",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "model = LeNet5().to(device)# 初始化模型，并将模型移至 GPU（如果可用）\n",
    "criterion = nn.CrossEntropyLoss()  # 分类任务使用交叉熵损失\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.001) #设置优化器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "e5a6d258",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_accuracy(loader, device):\n",
    "    correct = 0  # 初始化正确预测的数\n",
    "    total = 0    # 初始化总的样本数\n",
    "    with torch.no_grad():  # 不需计算梯度，这样可以节省内存和计算资源\n",
    "        for inputs, labels in loader:# 遍历数据加载器中的每一个 batch\n",
    "            inputs, labels = inputs.to(device), labels.to(device)  # 将输入数据和标签转移到指定的设备\n",
    "            outputs = model(inputs)  # 用当前的模型进行前向传播，得到预测结果\n",
    "            max_values, predicted = torch.max(outputs, 1)  # 对每个输出的预测结果，找到最大值对应的索引，返回预测的类标签\n",
    "            total += labels.size(0)  # 更新总样本数， labels.size(0) 是当前 batch 的样本数\n",
    "            correct += (predicted == labels).sum().item()  # 统计预测正确的样本\n",
    "    accuracy = 100 * correct / total  # 计算准确率\n",
    "    return accuracy  # 返回计算得到的准确率"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "46ba8fc5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/5], Loss: 0.8267\n",
      "Train Accuracy: 92.50%, Test Accuracy: 92.59%\n",
      "Epoch [2/5], Loss: 0.1962\n",
      "Train Accuracy: 95.50%, Test Accuracy: 95.65%\n",
      "Epoch [3/5], Loss: 0.1319\n",
      "Train Accuracy: 96.64%, Test Accuracy: 96.64%\n",
      "Epoch [4/5], Loss: 0.0990\n",
      "Train Accuracy: 97.11%, Test Accuracy: 96.99%\n",
      "Epoch [5/5], Loss: 0.0788\n",
      "Train Accuracy: 98.04%, Test Accuracy: 97.94%\n"
     ]
    }
   ],
   "source": [
    "num_epochs = 5  # 设置训练的总轮数\n",
    "\n",
    "for epoch in range(num_epochs):\n",
    "    model.train()  # 设置为训练模式\n",
    "    total_loss = 0  # 记录当前 epoch 的总损失\n",
    "    \n",
    "    for inputs, labels in train_loader:\n",
    "        inputs, labels = inputs.to(device), labels.to(device)\n",
    "        optimizer.zero_grad()  # 清除之前计算的梯度\n",
    "        outputs = model(inputs)  # 进行前向传播\n",
    "        loss = criterion(outputs, labels)  # 计算损失\n",
    "        loss.backward()  # 反向传播计算梯度\n",
    "        optimizer.step()  # 更新参数\n",
    "        total_loss += loss.item()  # 累加损失\n",
    "    \n",
    "    # 打印损失\n",
    "    print(f\"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss / len(train_loader):.4f}\")\n",
    "    # 打印训练集和测试集的准确率\n",
    "    train_accuracy = compute_accuracy(train_loader, device)\n",
    "    test_accuracy = compute_accuracy(test_loader, device)\n",
    "    print(f\"Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "c4358618",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "模型已保存为 lenet5.pth\n"
     ]
    }
   ],
   "source": [
    "\n",
    "torch.save(model.state_dict(), 'lenet5.pth') #保存模型\n",
    "print(\"模型已保存为 lenet5.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b1996af",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "857f36ee",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
